import os
import openai
import random
import numpy as np
import json
import jsonlines
import time
from tqdm import tqdm
from rank_bm25 import BM25Okapi

# The API key is taken from the environment so that no credential is stored in
# the source file. Export OPENAI_API_KEY before running this script.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
openai.api_key = OPENAI_API_KEY

# Instruction header placed at the very start of every prompt, ahead of the
# retrieved in-context examples and the actual question.
start_prompt = '''
You need to pick up one sentences from two captions most related to the given cue sentence. Here are five examples.
'''

def ask_gpt4(question):
    """Send ``question`` to GPT-4 as a single user message and return the reply text.

    Rate-limit and timeout errors are retried indefinitely, sleeping between
    attempts with a capped exponential back-off (starting at 0.2 s).

    Args:
        question: The full prompt text to send as the user message.

    Returns:
        The ``content`` string of the first choice in the response.

    Raises:
        Exception: If the API raises any other ``openai.error.OpenAIError``;
            the original error is attached as ``__cause__``.
    """
    messages = [{"role": "user", "content": question}]
    delay = 0.2
    max_delay = 10.0
    while True:
        try:
            response = openai.ChatCompletion.create(
                model="gpt-4",
                max_tokens=1000,
                temperature=1.2,
                messages=messages,
            )
        except (openai.error.RateLimitError, openai.error.Timeout):
            # Transient failure: wait, then try again with a longer pause so a
            # sustained rate limit is not hammered every 0.2 s.
            time.sleep(delay)
            delay = min(delay * 2, max_delay)
        except openai.error.OpenAIError as err:
            # Keep the generic exception type callers may rely on, but chain
            # the real cause so the failure can be diagnosed.
            raise Exception("Sorry, a problem happened") from err
        else:
            return response["choices"][0]["message"]["content"]
        
def read_jsonline(sample_file):
    """Turn each record of ``sample_file`` into one in-context example sentence.

    Args:
        sample_file: An object whose ``iter()`` method yields records. Each
            record must provide ``'labels'`` and ``'predictions'`` entries
            that can be indexed at positions 0 and 1.

    Returns:
        A list with one formatted example string per record, in file order.
    """
    template = (
        "The selected number of case is one, and the captions are %s. "
        "If the cue sentence is %s, the corresponding correct answer %s; "
        "elseif the cue sentence is %s, the corresponding correct labels is %s. "
    )
    samples = []
    for line in sample_file.iter():
        labels = line['labels']
        cues = line['predictions']
        # Each cue is paired with the label at the same index; the second cue
        # must therefore map to labels[1], not labels[0].
        samples.append(template % (str(labels), cues[0], labels[0], cues[1], labels[1]))
    return samples

def _retrieve_examples(bm25, corpus, query, k=5):
    """Return the ``k`` corpus entries with the highest BM25 score for ``query``."""
    scores = bm25.get_scores(query.split(" "))
    # sorted() is stable, so entries with equal scores keep their corpus order.
    ranked = sorted(zip(corpus, scores), key=lambda pair: pair[1], reverse=True)
    return [doc for doc, _score in ranked[:k]]


def _build_prompt(bm25, corpus, cue, captions):
    """Assemble the full prompt: header, retrieved examples, then the question."""
    query = "Now choose the one sentence most related to the cue sentence: %s from captions: %s." % (cue, str(captions))
    examples = "".join(_retrieve_examples(bm25, corpus, query))
    question = '''\nNow choose one sentence most related to the cue sentence: %s, from captions: %s. ''' % (cue, str(captions))
    instruction = '''Directly return the one sentences as answer (If cant just make a guess).'''
    return f'''{start_prompt}{examples}{question}{instruction}'''


def _run_task(test_path, out_path, captions_key, cues_key, corpus, bm25,
              resume_after=0, total=41):
    """Query GPT-4 for both cues of every test record and append results to ``out_path``.

    Args:
        test_path: JSON-lines file with the test records.
        out_path: JSON-lines file that results are appended to.
        captions_key: Record key holding the candidate sentences.
        cues_key: Record key holding the two cue sentences.
        corpus: In-context example strings that ``bm25`` was built from.
        bm25: BM25 index over ``corpus``.
        resume_after: Number of leading records to skip. After an interrupted
            run, set this to the count of records already written.
        total: Record count shown by the progress bar.
    """
    with jsonlines.open(test_path) as dataset, open(out_path, 'a') as outfile:
        with tqdm(desc='Process', unit='it', total=total) as pbar:
            for num, line in enumerate(dataset.iter(), start=1):
                if num > resume_after:
                    captions = line[captions_key]
                    cue1, cue2 = line[cues_key][0], line[cues_key][1]
                    answer1 = ask_gpt4(_build_prompt(bm25, corpus, cue1, captions))
                    answer2 = ask_gpt4(_build_prompt(bm25, corpus, cue2, captions))
                    line['gpt4_score_1'] = answer1
                    line['gpt4_score_2'] = answer2
                    json.dump(line, outfile)
                    outfile.write('\n')
                    # Flush per record so that finished work survives a crash
                    # and the run can be resumed from the right point.
                    outfile.flush()
                pbar.update()


if __name__=="__main__":
    # The same training file supplies the in-context examples for both tasks,
    # so the corpus and its BM25 index are built once.
    with jsonlines.open('./data/winoground/similar_icl/train.jsonl') as sample_file:
        corpus = read_jsonline(sample_file)
    bm25 = BM25Okapi([doc.split(" ") for doc in corpus])

    # Text score: the captions come from 'labels', the cues from 'predictions'.
    _run_task(
        './data/winoground/similar_icl/test.jsonl',
        './gpt4_ans/winoground/similar_icl/text_score/test.jsonl',
        'labels',
        'predictions',
        corpus,
        bm25,
    )

    # Image score: the roles are swapped, 'predictions' are the captions and
    # 'labels' provide the cues.
    _run_task(
        './data/winoground/similar_icl/test.jsonl',
        './gpt4_ans/winoground/similar_icl/image_score/test.jsonl',
        'predictions',
        'labels',
        corpus,
        bm25,
    )



